import numpy as np
import pickle


def load_data():
    """
    Load the dataset from '决策树数据.txt'.

    Each line holds 7 comma-separated integer fields; the last field is the
    label (y). Returns (train, test): the first 10 rows for training and the
    remaining rows for testing.
    """
    with open('决策树数据.txt') as fp:
        rows = fp.readlines()

    data = np.empty((len(rows), 7), dtype=int)
    for idx, row in enumerate(rows):
        # numpy parses the list of digit strings into the int row
        data[idx] = row.strip().split(',')

    return data[:10], data[10:]


def get_entropy(_x):
    """
    Shannon entropy of the label column (last column) of dataset `_x`.

    entropy = -sum(p * log2(p)) over the observed label frequencies.
    """
    labels = _x[:, -1]
    total = len(_x)
    result = 0
    # np.bincount returns per-value occurrence counts, e.g. [5, 5] means
    # label 0 appeared 5 times and label 1 appeared 5 times.
    for occurrences in np.bincount(labels):
        if occurrences != 0:  # absent labels contribute nothing
            p = occurrences / total
            result -= p * np.log2(p)
    return result


def get_gain(_x, col):
    """
    ID3 decision tree criterion: information gain of splitting `_x` on `col`.

    The gain is how much the label entropy drops after partitioning the rows
    by the values of `col`; larger is better.
    """
    weighted_entropy = 0
    # Partition the rows by each distinct value of the column.
    for value in set(_x[:, col]):
        subset = _x[_x[:, col] == value]
        # Weight = fraction of rows that fall into this partition.
        weight = len(subset) / len(_x)
        # Accumulate the weighted entropy of each partition.
        weighted_entropy += weight * get_entropy(subset)
    # NOTE: plain information gain favors columns with many distinct values
    # (the classic ID3 bias); see get_gain_c45 for the corrected ratio.
    return get_entropy(_x) - weighted_entropy


def get_gain_c45(_x, col):
    """
    C4.5 decision tree criterion: gain ratio of splitting `_x` on `col`.

    Divides the information gain by the split's intrinsic value (IV), which
    corrects ID3's bias toward columns with many distinct values.
    """
    weighted_entropy = 0
    # Tiny epsilon so a zero intrinsic value never divides by zero.
    intrinsic_value = 1e-20
    # Partition the rows by each distinct value of the column.
    for value in set(_x[:, col]):
        subset = _x[_x[:, col] == value]
        # Weight = fraction of rows that fall into this partition.
        weight = len(subset) / len(_x)
        weighted_entropy += weight * get_entropy(subset)
        # IV is the entropy of the partition sizes themselves.
        intrinsic_value -= weight * np.log2(weight)
    gain = get_entropy(_x) - weighted_entropy
    return gain / intrinsic_value


def get_split_col(_x):
    """
    Return the index of the column with the largest information gain.

    The last column is the label and is never considered.
    NOTE(review): returns -1 when no column has positive gain — callers
    appear to rely on only invoking this on impure subsets; confirm.
    """
    best_col, best_gain = -1, 0
    # All feature columns; shape[1]-1 skips the trailing label column.
    for candidate in range(_x.shape[1] - 1):
        candidate_gain = get_gain(_x, candidate)
        # Keep the column whose split drops entropy the most.
        if candidate_gain > best_gain:
            best_col, best_gain = candidate, candidate_gain
    return best_col


class Node:
    """
    Branch node of the tree.

    `col` is the column this node splits on; `children` maps each observed
    column value to the child Node or Leaf.
    """

    def __init__(self, col):
        self.col = col
        self.children = {}

    def __str__(self) -> str:
        return f'Node col={self.col}'


class Leaf:
    """
    Terminal node of the tree; `y` is the predicted label.
    """

    def __init__(self, y):
        self.y = y

    def __str__(self) -> str:
        return f'Leaf y={self.y}'


def print_tree(node, prefix='', subfix=''):
    """
    Recursively pretty-print the tree, one node per line.

    Each level adds four dashes of indentation; `subfix` labels the column
    value that led to this node.
    """
    indent = prefix + '-' * 4
    print(indent, node, subfix)
    # Leaves have no children; only recurse into branch nodes.
    if not isinstance(node, Leaf):
        for value, child in node.children.items():
            print_tree(child, indent, 'value=' + str(value))


def create_children(_x, parent_node):
    """
    Grow one level of the tree under `parent_node`.

    For every observed value of the parent's split column, attach either a
    Leaf (when the matching rows all share one label) or a fresh Node split
    on the best-gain column of that subset. Children are NOT expanded
    recursively — the caller grows the tree one level at a time.
    """
    for split_value in np.unique(_x[:, parent_node.col]):
        # Rows whose parent-column value equals this split value.
        subset = _x[_x[:, parent_node.col] == split_value]
        labels = np.unique(subset[:, -1])
        if len(labels) == 1:
            # Pure subset: every label agrees, so this branch terminates.
            parent_node.children[split_value] = Leaf(labels[0])
        else:
            # Mixed labels: branch again on the highest-gain column.
            parent_node.children[split_value] = Node(col=get_split_col(subset))


def pred(_x, node):
    """
    Predict the label of sample `_x` by walking the tree from `node`.

    At each branch, follow the child keyed by the sample's value in the
    node's split column; stop when a Leaf is reached.
    """
    child = node.children[_x[node.col]]
    return child.y if isinstance(child, Leaf) else pred(_x, child)


x, test_x = load_data()

# Root of the tree: split on the best column over the full training set.
root = Node(get_split_col(x))
create_children(x, root)
print_tree(root)
print('*' * 100)

# Manually grow the next level under the child reached by col0 == 0.
x_0_0 = x[x[:, 0] == 0]
create_children(x_0_0, root.children[0])
print_tree(root)
print('*' * 100)

# Grow the next level under the child reached by col0 == 1.
x_0_1 = x[x[:, 0] == 1]
create_children(x_0_1, root.children[1])
print_tree(root)
print('*' * 100)

# One more level: the subset with col0 == 1 and col1 == 1.
x_0_1_and_1_1 = x_0_1[x_0_1[:, 1] == 1]
create_children(x_0_1_and_1_1, root.children[1].children[1])
print_tree(root)
print('*' * 100)

# Accuracy on the training split.
correct = sum(1 for sample in x if pred(sample, root) == sample[-1])
print('训练集', correct / len(x))
print('*' * 100)

# Accuracy on the held-out test split.
correct = sum(1 for sample in test_x if pred(sample, root) == sample[-1])
print('测试集', correct / len(test_x))

# Serialize the tree for later use (pruning in a follow-up step).
with open('tree.dump', 'wb') as fp:
    pickle.dump(root, fp)
